import torch.nn as nn

from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper


class GRU(PytorchSeq2VecWrapper):
    """Seq2Vec encoder wrapping a batch-first ``torch.nn.GRU``.

    The per-direction hidden size is chosen by :meth:`get_hidden_size` so
    that the combined width is 256 whether or not the GRU is bidirectional.
    """

    # Combined hidden width across all directions.
    _TOTAL_HIDDEN_SIZE = 256
    # Dropout probability applied between stacked GRU layers.
    _DROPOUT = 0.5

    def __init__(self, input_dim, config):
        """Build the underlying ``nn.GRU`` and hand it to the wrapper.

        Args:
            input_dim: Size of each input feature vector (``input_size`` of
                the GRU).
            config: Mapping read for the keys ``'bidirectional'`` and
                ``'num_layers'``; both values are forwarded to ``nn.GRU``.

        Raises:
            KeyError: If ``config`` lacks one of the keys above.
        """
        bidirectional = config['bidirectional']
        num_layers = config['num_layers']

        # nn.GRU only applies dropout *between* stacked layers. With a single
        # layer a non-zero value does nothing except trigger a UserWarning,
        # so pass 0.0 in that case.
        dropout = self._DROPOUT if num_layers > 1 else 0.0

        super().__init__(nn.GRU(input_size=input_dim,
                                hidden_size=self.get_hidden_size(bidirectional),
                                num_layers=num_layers,
                                bidirectional=bidirectional,
                                dropout=dropout,
                                batch_first=True))

    def get_hidden_size(self, bidirectional) -> int:
        """Return the per-direction hidden size for the GRU.

        Args:
            bidirectional: Truthy when the GRU runs in both directions.

        Returns:
            128 when ``bidirectional`` is truthy (half of 256, so the two
            directions together total 256), otherwise 256.
        """
        if bidirectional:
            return self._TOTAL_HIDDEN_SIZE // 2
        return self._TOTAL_HIDDEN_SIZE